import tensorflow as tf


class Multi_Agent_NN:
    """
    Player-embedding network with several heads, built as a TensorFlow 1.x graph.

    Three input sequences (positive, reference, negative) are encoded by the
    same list of LSTM cells and the same dense "embed" stack into vectors of
    size ``config.Arch.Episodic.latent_dim``. The positive embedding
    (``episode_embedding``) then feeds:

    * an embedding loss comparing the reference embedding with the positive
      and negative ones,
    * an "Identify" softmax classifier over ``config.Learn.player_number``
      classes,
    * a per-time-step policy head predicting the reference action sequence,
    * a score-difference regression head and a "predict" softmax head, each
      with its own LSTM over the positive sequence.

    ``__init__`` creates the variables and placeholders; :meth:`call` must be
    invoked afterwards to wire up the forward graph, the losses and the
    training ops. Variable and scope names (including the "xaiver" spelling)
    are kept as-is, presumably because saved checkpoints depend on them.
    """

    def __init__(self,
                 config):
        """
        Create all variables and placeholders of the model.

        :param config: configuration object whose ``Learn`` and ``Arch``
            sections supply the sizes used throughout the class.
        """
        self.config = config

        # Placeholders / input tensors, assigned in initialize_ph().
        self.positive_state_input_ph = None
        self.reference_state_input_ph = None
        self.negative_state_input_ph = None
        self.reference_state_mask_ph = None
        self.reference_action_seq_ph = None
        self.score_diff_target_ph = None
        self.predict_target_ph = None
        self.player_id_ph = None

        self.positive_trace_lengths_ph = None
        self.reference_trace_lengths_ph = None
        self.negative_trace_lengths_ph = None

        # LSTM cells and dense-layer variables, filled in by build().
        self.lstm_cell_all = []
        self.lstm_score_diff_cell_all = []
        self.lstm_predict_cell_all = []

        self.embed_layer_weights = []
        self.embed_layer_bias = []

        self.identify_layer_weights = []
        self.identify_layer_bias = []

        self.policy_layer_weights = []
        self.policy_layer_bias = []

        self.score_diff_weights = []
        self.score_diff_bias = []

        self.predict_weights = []
        self.predict_bias = []

        # Set by call().
        self.policy_read_out = None
        self.train_embedding_op = None

        self.build()
        self.initialize_ph()
        # Dynamic batch size taken from the positive input; the reference and
        # negative batches are assumed to have the same size.
        self.batch_size = tf.shape(self.positive_state_input_ph)[0]

    def build(self):
        """
        Create the LSTM cells and the dense-layer variables of every head.

        Only cells and variables are created here; they are connected into a
        graph by :meth:`call`, :meth:`score_diff_value_function` and
        :meth:`predict_value_function`.
        """
        learn_cfg = self.config.Learn
        episodic_cfg = self.config.Arch.Episodic
        dense_cfg = self.config.Arch.Dense

        # Width of the leading slice of each input row that the policy head
        # consumes (assumes the action columns come last — TODO confirm).
        state_feature_number = learn_cfg.feature_number - learn_cfg.action_number

        with tf.variable_scope("LSTM_layer"):
            for _ in range(episodic_cfg.lstm_layer_num):
                self.lstm_cell_all.append(
                    tf.nn.rnn_cell.LSTMCell(num_units=episodic_cfg.h_size, state_is_tuple=True,
                                            initializer=tf.random_uniform_initializer(-0.05, 0.05)))

        with tf.variable_scope("Embed_layer"):
            for i in range(dense_cfg.embed_layer_number):
                w_input_size = dense_cfg.dense_layer_size if i > 0 else episodic_cfg.h_size
                # The last embed layer projects onto the latent embedding.
                w_output_size = (dense_cfg.dense_layer_size if i < dense_cfg.embed_layer_number - 1
                                 else episodic_cfg.latent_dim)
                self.embed_layer_weights.append(
                    tf.get_variable('e_w{0}_xaiver'.format(i), [w_input_size, w_output_size],
                                    initializer=tf.contrib.layers.xavier_initializer()))
                self.embed_layer_bias.append(tf.Variable(tf.zeros([w_output_size]), name="e_b_{0}".format(i)))

        with tf.variable_scope("Identify"):
            for i in range(dense_cfg.identify_layer_number):
                with tf.variable_scope("Identify_layer_{0}".format(i)):
                    w_input_size = episodic_cfg.latent_dim if i == 0 else dense_cfg.dense_layer_size
                    # The last identify layer outputs one logit per player.
                    w_output_size = (dense_cfg.dense_layer_size if i < dense_cfg.identify_layer_number - 1
                                     else learn_cfg.player_number)
                    self.identify_layer_weights.append(
                        tf.get_variable('i_w{0}_xaiver'.format(i), [w_input_size, w_output_size],
                                        initializer=tf.contrib.layers.xavier_initializer()))
                    self.identify_layer_bias.append(
                        tf.Variable(tf.zeros([w_output_size]), name="i_b_{0}".format(i)))

        with tf.variable_scope("Policy_layer"):
            # Projects the state slice of a single time step to dense_layer_size.
            self.policy_state_weight = tf.get_variable('p_s_xaiver',
                                                       [state_feature_number, dense_cfg.dense_layer_size],
                                                       initializer=tf.contrib.layers.xavier_initializer())
            self.policy_state_bias = tf.Variable(tf.zeros([dense_cfg.dense_layer_size]), name="p_s_b")

            for i in range(dense_cfg.policy_layer_number):
                # The first policy layer sees [projected state, episode embedding].
                w_input_size = (dense_cfg.dense_layer_size if i > 0
                                else dense_cfg.dense_layer_size + episodic_cfg.latent_dim)
                # The last policy layer outputs one logit per action.
                w_output_size = (dense_cfg.dense_layer_size if i < dense_cfg.policy_layer_number - 1
                                 else learn_cfg.action_number)
                self.policy_layer_weights.append(
                    tf.get_variable('p_w{0}_xaiver'.format(i), [w_input_size, w_output_size],
                                    initializer=tf.contrib.layers.xavier_initializer()))
                self.policy_layer_bias.append(tf.Variable(tf.zeros([w_output_size]), name="p_b_{0}".format(i)))

        score_diff_cfg = self.config.Arch.ScoreDiff
        with tf.variable_scope("score_diff"):
            # The first dense layer consumes [score-diff LSTM output, embedding].
            self.score_diff_output_weight, self.score_diff_output_bias = self._build_value_head(
                head_cfg=score_diff_cfg, lstm_scope="score_diff_LSTM",
                input_size=episodic_cfg.latent_dim + score_diff_cfg.h_size,
                cell_list=self.lstm_score_diff_cell_all,
                weight_list=self.score_diff_weights, bias_list=self.score_diff_bias)

        predict_cfg = self.config.Arch.Predict
        with tf.variable_scope("predict"):
            # The first dense layer consumes [predict LSTM output, embedding], so its
            # width follows Predict.h_size (ScoreDiff.h_size only worked when the
            # two sizes happened to match).
            self.predict_output_weight, self.predict_output_bias = self._build_value_head(
                head_cfg=predict_cfg, lstm_scope="predict_LSTM",
                input_size=episodic_cfg.latent_dim + predict_cfg.h_size,
                cell_list=self.lstm_predict_cell_all,
                weight_list=self.predict_weights, bias_list=self.predict_bias)

    def _build_value_head(self, head_cfg, lstm_scope, input_size, cell_list, weight_list, bias_list):
        """
        Create the LSTM cells, hidden dense variables and output variables of a value head.

        Must be called inside the head's own variable scope.

        :param head_cfg: config section with ``lstm_layer_num``, ``h_size``,
            ``layer_num``, ``n_hidden`` and ``output_node``.
        :param lstm_scope: variable-scope name under which the cells are created.
        :param input_size: input width of the first hidden dense layer.
        :param cell_list: list that receives the created LSTM cells.
        :param weight_list: list that receives the hidden-layer weights.
        :param bias_list: list that receives the hidden-layer biases.
        :return: (output_weight, output_bias)
        """
        w_init = tf.random_normal_initializer(stddev=0.02)
        b_init = tf.constant_initializer(0.)

        with tf.variable_scope(lstm_scope):
            for _ in range(head_cfg.lstm_layer_num):
                cell_list.append(
                    tf.nn.rnn_cell.LSTMCell(num_units=head_cfg.h_size, state_is_tuple=True,
                                            initializer=tf.random_uniform_initializer(-0.05, 0.05)))

        for i in range(head_cfg.layer_num):
            with tf.variable_scope("Dense_Layer_{0}".format(i)):
                rows = input_size if i == 0 else head_cfg.n_hidden
                w = tf.get_variable('weight_{0}'.format(i), [rows, head_cfg.n_hidden], initializer=w_init)
                b = tf.get_variable("bias_{0}".format(i), [head_cfg.n_hidden], initializer=b_init)
                weight_list.append(w)
                bias_list.append(b)

        with tf.variable_scope("output_Layer"):
            output_weight = tf.get_variable('weight_out', [head_cfg.n_hidden, head_cfg.output_node],
                                            initializer=w_init)
            output_bias = tf.get_variable("bias_out", [head_cfg.output_node], initializer=b_init)
        return output_weight, output_bias

    def _last_step_output(self, rnn_output, trace_lengths, h_size):
        """
        Gather each sequence's RNN output at its last valid time step.

        ``rnn_output`` is the batch-major ``dynamic_rnn`` output of shape
        (batch, max_seq_length, h_size). It is flattened to rows, and row
        ``b * max_seq_length + trace_lengths[b] - 1`` is taken for sample ``b``.
        ``self.batch_size`` (from the positive input) is used as the batch size.

        :param rnn_output: LSTM output tensor.
        :param trace_lengths: int32 tensor (batch,) of valid sequence lengths.
        :param h_size: number of LSTM units.
        :return: (last outputs of shape (batch, h_size), flat row index)
        """
        index = tf.range(0, self.batch_size) * self.config.Learn.max_seq_length + (trace_lengths - 1)
        last_outputs = tf.gather(tf.reshape(rnn_output, [-1, h_size]), index)
        return last_outputs, index

    def call(self):
        """
        Build the forward graph, the losses and the training ops.

        Must run after ``__init__``. It also prints the names of the variables
        handed to each optimizer.

        Sets ``episode_embedding``, ``embedding_loss_sum``,
        ``identify_read_out``, ``identify_lost``, ``policy_read_out``,
        ``policy_lost_sum``, ``score_diff_output``, ``td_score_diff_loss``,
        ``predict_output``, ``predict_loss`` and the training ops
        ``train_embedding_op``, ``train_identify_op``, ``train_diff_op`` and
        ``train_predict_op``.
        :return: None
        """
        learn_cfg = self.config.Learn
        episodic_cfg = self.config.Arch.Episodic
        dense_cfg = self.config.Arch.Dense
        state_feature_number = learn_cfg.feature_number - learn_cfg.action_number

        state_input_ph = [self.positive_state_input_ph, self.reference_state_input_ph, self.negative_state_input_ph]
        # Each sequence is unrolled and indexed with its own lengths.
        trace_lengths_ph = [self.positive_trace_lengths_ph, self.reference_trace_lengths_ph,
                            self.negative_trace_lengths_ph]

        with tf.variable_scope("Episodic_Embedding_layer"):
            episode_embedding_all = []
            for j in range(3):
                with tf.variable_scope("LSTM_layers_{0}".format(j)):
                    rnn_output = state_input_ph[j]
                    for i in range(episodic_cfg.lstm_layer_num):
                        rnn_output, _ = tf.nn.dynamic_rnn(  # while-loop based dynamic unrolling
                            inputs=rnn_output, cell=self.lstm_cell_all[i],
                            sequence_length=trace_lengths_ph[j], dtype=tf.float32,
                            scope='rnn_{0}'.format(i))
                    rnn_last, self.index = self._last_step_output(rnn_output, trace_lengths_ph[j],
                                                                  episodic_cfg.h_size)

                with tf.variable_scope("Embed_layer_{0}".format(j)):
                    dense_output = rnn_last
                    for i in range(dense_cfg.embed_layer_number):
                        dense_output = tf.matmul(dense_output, self.embed_layer_weights[i]) + self.embed_layer_bias[i]
                        # ReLU on hidden layers only; the final embedding stays linear.
                        if i < dense_cfg.embed_layer_number - 1:
                            dense_output = tf.nn.relu(dense_output, name='activation_{0}'.format(i))
                    episode_embedding_all.append(dense_output)
            positive_embedding, reference_embedding, negative_embedding = episode_embedding_all
            self.episode_embedding = positive_embedding

        with tf.variable_scope('embedding_loss'):
            # Element-wise |a - b|: same value as sqrt(pow(a - b, 2)) but with a
            # finite gradient when a == b (sqrt has an infinite slope at 0 -> NaN).
            component_1 = tf.abs(reference_embedding - negative_embedding)
            component_2 = tf.abs(reference_embedding - positive_embedding)
            # NOTE(review): minimising exp(component_1 - component_2) pulls the
            # reference towards the *negative* embedding and away from the
            # *positive* one, and uses sqrt(1 + exp(.)) rather than a
            # log(1 + exp(.)) soft margin — confirm both are intended.
            embedding_loss_all = tf.math.sqrt(1 + tf.math.exp(component_1 - component_2))
            self.embedding_loss_sum = tf.reduce_sum(embedding_loss_all, axis=1)

        with tf.variable_scope("Identify"):
            identify_output = self.episode_embedding
            for i in range(dense_cfg.identify_layer_number):
                identify_output = (tf.matmul(identify_output, self.identify_layer_weights[i])
                                   + self.identify_layer_bias[i])
                # ReLU on hidden layers only, so the last layer yields raw logits.
                if i < dense_cfg.identify_layer_number - 1:
                    identify_output = tf.nn.relu(identify_output, name='activation_{0}'.format(i))

            self.identify_read_out = tf.nn.softmax(identify_output)
            self.identify_lost = tf.losses.softmax_cross_entropy(onehot_labels=self.player_id_ph,
                                                                 logits=identify_output,
                                                                 reduction='none')

        with tf.variable_scope('Policy_layer'):
            self.policy_lost_all = []
            for j in range(learn_cfg.max_seq_length):
                # Project the state slice of step j, then append the episode embedding.
                dense_state = tf.matmul(self.reference_state_input_ph[:, j, :state_feature_number],
                                        self.policy_state_weight) + self.policy_state_bias
                policy_output = tf.concat([dense_state, self.episode_embedding], axis=1)
                for i in range(dense_cfg.policy_layer_number):
                    policy_output = tf.matmul(policy_output, self.policy_layer_weights[i]) + self.policy_layer_bias[i]
                    if i < dense_cfg.policy_layer_number - 1:
                        policy_output = tf.nn.relu(policy_output, name='activation_{1}_{0}'.format(i, j))
                # Overwritten every step: after the loop it holds the last step's distribution.
                self.policy_read_out = tf.nn.softmax(policy_output)

                with tf.variable_scope('cross_entropy_{0}'.format(j)):
                    policy_lost_step = tf.losses.softmax_cross_entropy(
                        onehot_labels=self.reference_action_seq_ph[:, j, :],
                        logits=policy_output,
                        reduction='none')
                    self.policy_lost_all.append(policy_lost_step)

        with tf.variable_scope('imitation_loss'):
            # (batch, max_seq_length) per-step losses; steps where the mask is False
            # (presumably padding) are zeroed before summing over time.
            policy_lost_all = tf.stack(self.policy_lost_all, axis=1)
            zero_lost_all = tf.zeros(tf.shape(policy_lost_all))
            policy_lost_all = tf.where(condition=self.reference_state_mask_ph, x=policy_lost_all, y=zero_lost_all)
            self.policy_lost_sum = tf.reduce_sum(policy_lost_all, axis=1)

        self.score_diff_output, self.td_score_diff_loss = self.score_diff_value_function(z=self.episode_embedding)

        self.predict_output, self.predict_loss = self.predict_value_function(z=self.episode_embedding)

        with tf.variable_scope("train"):
            # Imitation loss plus a small weight on the embedding loss, applied to
            # every trainable variable (variables with no gradient are skipped).
            total_loss = tf.reduce_mean(self.policy_lost_sum + 0.0001 * self.embedding_loss_sum)
            embedding_op = tf.train.AdamOptimizer(learning_rate=learn_cfg.learning_rate)
            tvars_embed_player = tf.trainable_variables()
            for t in tvars_embed_player:
                print('tvars_embed_player: ' + str(t.name))
            player_embedding_grads = tf.gradients(total_loss, tvars_embed_player)
            self.train_embedding_op = embedding_op.apply_gradients(zip(player_embedding_grads, tvars_embed_player))

            # Identify classifier: only the variables under the 'Identify' scope.
            identify_op = tf.train.AdamOptimizer(learning_rate=learn_cfg.learning_rate)
            tvars_identify_player = tf.trainable_variables(scope='Identify')
            for t in tvars_identify_player:
                print('tvars_Identify_layer: ' + str(t.name))
            player_identify_grads = tf.gradients(tf.reduce_mean(self.identify_lost), tvars_identify_player)
            self.train_identify_op = identify_op.apply_gradients(zip(player_identify_grads, tvars_identify_player))

            # NOTE(review): the score-diff and predict optimizers receive *all*
            # trainable variables (the shared embedding included), not only their
            # own 'score_diff' / 'predict' scopes — confirm this joint update is intended.
            tvars_score_diff = tf.trainable_variables()
            for t in tvars_score_diff:
                print('score_diff: ' + str(t.name))
            td_diff_grads = tf.gradients(tf.reduce_mean(self.td_score_diff_loss), tvars_score_diff)
            score_diff_op = tf.train.AdamOptimizer(learning_rate=learn_cfg.learning_rate)
            self.train_diff_op = score_diff_op.apply_gradients(zip(td_diff_grads, tvars_score_diff))

            tvars_predict = tf.trainable_variables()
            for t in tvars_predict:
                print('predict: ' + str(t.name))
            predict_grads = tf.gradients(tf.reduce_mean(self.predict_loss), tvars_predict)
            predict_op = tf.train.AdamOptimizer(learning_rate=learn_cfg.learning_rate)
            self.train_predict_op = predict_op.apply_gradients(zip(predict_grads, tvars_predict))

    def initialize_ph(self):
        """
        Create the input placeholders and store them on ``self``.

        Shapes (``T`` = ``config.Learn.max_seq_length``, ``F`` =
        ``config.Learn.feature_number``, ``A`` = ``config.Learn.action_number``):

        * ``{positive,reference,negative}_state_input_ph``: float32 (batch, T, F)
        * ``{positive,reference,negative}_trace_lengths_ph``: int32 (batch,)
        * ``reference_action_seq_ph``: float32 (batch, T, A)
        * ``player_id_ph``: int32 (batch, config.Learn.player_number)
        * ``reference_state_mask_ph``: bool (batch, T), cast from an int32
          placeholder named "reference_state_mask_ph"
        * ``score_diff_target_ph``: float32 (batch, 3)
        * ``predict_target_ph``: float32 (batch, config.Arch.Predict.output_node)
        :return: None
        """
        learn_cfg = self.config.Learn
        seq_len = learn_cfg.max_seq_length

        self.positive_state_input_ph = tf.placeholder(dtype=tf.float32,
                                                      shape=[None, seq_len, learn_cfg.feature_number],
                                                      name="positive_state_input_ph")
        self.reference_state_input_ph = tf.placeholder(dtype=tf.float32,
                                                       shape=[None, seq_len, learn_cfg.feature_number],
                                                       name="reference_state_input_ph")
        self.negative_state_input_ph = tf.placeholder(dtype=tf.float32,
                                                      shape=[None, seq_len, learn_cfg.feature_number],
                                                      name="negative_state_input_ph")

        self.positive_trace_lengths_ph = tf.placeholder(dtype=tf.int32, shape=[None],
                                                        name="positive_trace_lengths_ph")
        self.reference_trace_lengths_ph = tf.placeholder(dtype=tf.int32, shape=[None],
                                                         name="reference_trace_lengths_ph")
        self.negative_trace_lengths_ph = tf.placeholder(dtype=tf.int32, shape=[None],
                                                        name="negative_trace_lengths_ph")

        self.reference_action_seq_ph = tf.placeholder(dtype=tf.float32,
                                                      shape=[None, seq_len, learn_cfg.action_number],
                                                      name="reference_action_seq_ph")
        self.player_id_ph = tf.placeholder(dtype=tf.int32, shape=[None, learn_cfg.player_number],
                                           name="player_id_ph")

        self.reference_state_mask_ph = tf.cast(tf.placeholder(dtype=tf.int32, shape=[None, seq_len],
                                                              name="reference_state_mask_ph"), tf.bool)

        # NOTE(review): width is hard-coded to 3, but it is subtracted from the
        # score-diff head output of width config.Arch.ScoreDiff.output_node —
        # confirm these agree.
        self.score_diff_target_ph = tf.placeholder(dtype=tf.float32, shape=[None, 3], name='score_diff_target')

        self.predict_target_ph = tf.placeholder(dtype=tf.float32,
                                                shape=[None, self.config.Arch.Predict.output_node],
                                                name='predict_target')

    def score_diff_value_function(self, z):
        """
        Score-difference regression head.

        Runs the score-diff LSTM stack over the positive sequence and
        concatenates its last valid output with ``z``. The result then goes
        through the ReLU dense stack and a linear output layer.

        :param z: embedding tensor concatenated (axis 1) with the LSTM output;
            ``call`` passes ``episode_embedding``.
        :return: (output, td_score_diff_loss), where ``td_score_diff_loss`` is
            the per-sample mean squared error against ``score_diff_target_ph``.
        """
        head_cfg = self.config.Arch.ScoreDiff
        with tf.variable_scope("score_diff"):

            with tf.name_scope('diff-lstm-layer'):
                rnn_output = self.positive_state_input_ph
                for i in range(head_cfg.lstm_layer_num):
                    rnn_output, _ = tf.nn.dynamic_rnn(  # while-loop based dynamic unrolling
                        inputs=rnn_output, cell=self.lstm_score_diff_cell_all[i],
                        sequence_length=self.positive_trace_lengths_ph, dtype=tf.float32,
                        scope='score_diff_rnn_{0}'.format(i))
                rnn_last, _ = self._last_step_output(rnn_output, self.positive_trace_lengths_ph, head_cfg.h_size)

            with tf.name_scope('diff-dense-layer'):
                dense_output = tf.concat([rnn_last, z], axis=1)
                for i in range(head_cfg.layer_num):
                    dense_output = tf.matmul(dense_output, self.score_diff_weights[i]) + self.score_diff_bias[i]
                    dense_output = tf.nn.relu(dense_output, name='activation_{0}'.format(i))

            with tf.name_scope('diff-output-layer'):
                output = tf.matmul(dense_output, self.score_diff_output_weight) + self.score_diff_output_bias

            with tf.name_scope('diff-loss'):
                td_score_diff_loss = tf.reduce_mean(tf.square(output - self.score_diff_target_ph), axis=-1)

        return output, td_score_diff_loss

    def predict_value_function(self, z):
        """
        Softmax "predict" head.

        Runs the predict LSTM stack over the positive sequence and concatenates
        its last valid output with ``z``. The result then goes through the ReLU
        dense stack and a linear output layer.

        :param z: embedding tensor concatenated (axis 1) with the LSTM output;
            ``call`` passes ``episode_embedding``.
        :return: (softmax probabilities, predict_loss), where ``predict_loss``
            is the scalar batch-mean softmax cross-entropy against
            ``predict_target_ph``.
        """
        head_cfg = self.config.Arch.Predict
        with tf.variable_scope("predict"):
            with tf.name_scope('predict-lstm-layer'):
                rnn_output = self.positive_state_input_ph
                for i in range(head_cfg.lstm_layer_num):
                    rnn_output, _ = tf.nn.dynamic_rnn(  # while-loop based dynamic unrolling
                        inputs=rnn_output, cell=self.lstm_predict_cell_all[i],
                        sequence_length=self.positive_trace_lengths_ph, dtype=tf.float32,
                        scope='predict_rnn_{0}'.format(i))
                rnn_last, _ = self._last_step_output(rnn_output, self.positive_trace_lengths_ph, head_cfg.h_size)

            with tf.name_scope('predict-dense-layer'):
                dense_output = tf.concat([rnn_last, z], axis=1)
                for i in range(head_cfg.layer_num):
                    dense_output = tf.matmul(dense_output, self.predict_weights[i]) + self.predict_bias[i]
                    dense_output = tf.nn.relu(dense_output, name='activation_{0}'.format(i))

            with tf.name_scope('predict-output-layer'):
                output = tf.matmul(dense_output, self.predict_output_weight) + self.predict_output_bias

            with tf.name_scope('predict-loss'):
                predict_loss = tf.reduce_mean(tf.losses.softmax_cross_entropy(onehot_labels=self.predict_target_ph,
                                                                               logits=output,
                                                                               reduction=tf.losses.Reduction.NONE))

        return tf.nn.softmax(output), predict_loss
